{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Narrowband Dataset\n",
    "This notebook showcases the Narrowband dataset.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Variables\n",
    "\n",
    "num_iq_samples_dataset = 4096 # 64^2\n",
    "fft_size = 64\n",
    "impairment_level = 0 # clean"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Narrowband Metadata\n",
    "In order to create a NewNarrowband dataset, you must define parameters in NarrowbandMetadata. This can be done either in code or inside a YAML file. Below we show how to do both. Look at `narrowband_example.yaml` for a sample YAML file.\n",
    "\n",
    "There are three required parameters: \n",
    "1. `num_iq_samples_dataset` -> how much IQ data per sample\n",
    "2. `fft_size` -> Size of FFT (number of bins) to be used in spectrogram.\n",
    "3. `impairment_level` -> what environment impairment to simulate.\n",
    "\n",
    "Additionally, there are several optional parameters that can be overriden."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 1: Instantiate NarrowbandMetadata object\n",
    "from torchsig.datasets.dataset_metadata import NarrowbandMetadata\n",
    "\n",
    "narrowband_metadata_1 = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset, # 64^2\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level, # clean\n",
    ")\n",
    "print(narrowband_metadata_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 2: Instantiate as a dictionary object\n",
    "\n",
    "narrowband_metadata_2 = dict(\n",
    "    num_iq_samples = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    ")\n",
    "print(narrowband_metadata_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 3: Instantiate from YAML file\n",
    "# see narrowband_example.yaml\n",
    "\n",
    "narrowband_metadata_3 = \"narrowband_example.yaml\"\n",
    "print(narrowband_metadata_3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Infinite Dataset\n",
    "To create an infinite dataset, simply instantiate a NewNarrowband object. Ensure the `num_samples_dataset` field inside the NarrowbandMetadata is set to None.\n",
    "\n",
    "Notes:\n",
    "* You will not be able to access previously generated samples.\n",
    "* You cannot save an infinite dataset to disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import NarrowbandMetadata\n",
    "from torchsig.datasets.narrowband import NewNarrowband\n",
    "\n",
    "narrowband_infinite_metadata = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_samples = None,\n",
    ")\n",
    "\n",
    "narrowband_infinite = NewNarrowband(\n",
    "    dataset_metadata = narrowband_infinite_metadata\n",
    ")\n",
    "# print(narrowband_infinite)\n",
    "\n",
    "for i in range(10):\n",
    "    data, metadata = narrowband_infinite[i]\n",
    "    if i == 5:\n",
    "        print(f\"IQ Data: {data.shape}\")\n",
    "        print(f\"Signal Metadata: {metadata}\")\n",
    "\n",
    "# should fail\n",
    "print()\n",
    "try:\n",
    "    narrowband_infinite[0]\n",
    "except Exception as E:\n",
    "    print(E)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finite Dataset\n",
    "To create a finite dataset, set `num_samples` = any positive integer number.\n",
    "\n",
    "Note: you will still not be able to access previously generated samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import NarrowbandMetadata\n",
    "from torchsig.datasets.narrowband import NewNarrowband\n",
    "\n",
    "\n",
    "narrowband_finite_metadata = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_samples = 10,\n",
    ")\n",
    "\n",
    "narrowband_finite = NewNarrowband(\n",
    "    dataset_metadata = narrowband_finite_metadata\n",
    ")\n",
    "# print(narrowband_finite)\n",
    "\n",
    "for i in range(11):\n",
    "    if i == 10: # should fail\n",
    "        try:\n",
    "            data, metadata = narrowband_finite[i]\n",
    "        except Exception as E:\n",
    "            print(E)\n",
    "    else:\n",
    "        data, metadata = narrowband_finite[i]\n",
    "        print(f\"sample {i} generated successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing Dataset to Disk\n",
    "In order to access previosuly generated examples, or save the finite dataset for later, use the `DatasetCreator`. Pass in the Dataset to be saved, where to write the dataset (root), and whether to overwrite any exisitng datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import NarrowbandMetadata\n",
    "from torchsig.datasets.narrowband import NewNarrowband\n",
    "from torchsig.utils.writer import DatasetCreator\n",
    "from torchsig.signals.signal_lists import TorchSigSignalLists\n",
    "\n",
    "root = \"./datasets/narrowband_example\"\n",
    "class_list = TorchSigSignalLists.all_signals\n",
    "num_samples = len(class_list) * 10\n",
    "seed = 775489032758493\n",
    "\n",
    "narrowband_finite_metadata = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_samples = num_samples,\n",
    "    class_list = class_list,\n",
    "    seed = seed\n",
    ")\n",
    "\n",
    "narrowband_finite = NewNarrowband(\n",
    "    dataset_metadata = narrowband_finite_metadata\n",
    ")\n",
    "\n",
    "dataset_creator = DatasetCreator(\n",
    "    dataset = narrowband_finite,\n",
    "    root = root,\n",
    "    # overwrite = True,\n",
    "    # increase the batch size and num_workers to speed up\n",
    "    # batch_size = 24,\n",
    "    # num_workers = 24\n",
    ")\n",
    "\n",
    "dataset_creator.create()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reading Dataset from Disk\n",
    "Assuming you wrote a dataset to disk, you can load it back in my instantiating a `StaticNarrowband`.\n",
    "\n",
    "Warning: The following code assumes you have the code in the section **Writing Dataset to Disk**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.narrowband import StaticNarrowband\n",
    "\n",
    "static_narrowband = StaticNarrowband(\n",
    "    root = root,\n",
    "    impaired = impairment_level > 0,\n",
    ")\n",
    "\n",
    "# can access any sample\n",
    "print(static_narrowband[0])\n",
    "print(static_narrowband[5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the dataset writtent to disk is raw (aka no Transforms or Target Transforms were applied to it before writing to disk), then you can define whatever transforms and target transforms inside the StaticNarrowband."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Dataset Raw: {static_narrowband.raw}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "from torchsig.datasets.narrowband import StaticNarrowband\n",
    "from torchsig.transforms.dataset_transforms import Spectrogram\n",
    "from torchsig.transforms.target_transforms import (\n",
    "    ClassName,\n",
    "    Start,\n",
    "    Stop,\n",
    "    LowerFreq,\n",
    "    UpperFreq,\n",
    "    SNR\n",
    ")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "target_transform = [\n",
    "    ClassName(),\n",
    "    Start(),\n",
    "    Stop(),\n",
    "    LowerFreq(),\n",
    "    UpperFreq(),\n",
    "    SNR()\n",
    "]\n",
    "\n",
    "static_narrowband = StaticNarrowband(\n",
    "    root = root,\n",
    "    impaired = impairment_level > 0,\n",
    "    transforms = Spectrogram(fft_size = fft_size),\n",
    "    target_transforms = target_transform,\n",
    ")\n",
    "sample_rate = static_narrowband.dataset_metadata.sample_rate\n",
    "noise_power_db = static_narrowband.dataset_metadata.noise_power_db\n",
    "\n",
    "num_show = 4\n",
    "num_cols = 2\n",
    "num_rows = num_show // num_cols\n",
    "\n",
    "fig = plt.figure(figsize=(18,12))\n",
    "\n",
    "xmin = 0\n",
    "xmax = 1\n",
    "ymin = -sample_rate / 2\n",
    "ymax = sample_rate / 2\n",
    "\n",
    "# fig.suptitle(f\"class: {classname}\", fontsize=16)\n",
    "# plt.ylabel(\"Frequency (Hz)\")\n",
    "# plt.xlabel(\"Time\")\n",
    "fig.tight_layout()\n",
    "\n",
    "for i in range(num_show):\n",
    "    data, targets = static_narrowband[i]\n",
    "    ax = fig.add_subplot(num_rows,num_cols,i + 1)\n",
    "\n",
    "    pos = ax.imshow(data,extent=[xmin,xmax,ymin,ymax],aspect='auto',cmap='Wistia',vmin=noise_power_db)\n",
    "    fig.colorbar(pos, ax=ax)\n",
    "\n",
    "    # for t in targets:\n",
    "    classname, start, stop, lower, upper, snr = targets\n",
    "\n",
    "\n",
    "    ax.plot([start,start],[lower,upper],'b',alpha=0.5)\n",
    "    ax.plot([stop, stop],[lower,upper],'b',alpha=0.5)\n",
    "    ax.plot([start,stop],[lower,lower],'b',alpha=0.5)\n",
    "    ax.plot([start,stop],[upper,upper],'b',alpha=0.5)\n",
    "    textDisplay = str(classname) + ', SNR = ' + str(snr) + ' dB'\n",
    "    ax.text(start,lower,textDisplay, bbox=dict(facecolor='w', alpha=0.5, linewidth=0))\n",
    "    ax.set_xlim([0,1])\n",
    "    ax.set_ylim([-sample_rate/2,sample_rate/2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Narrowband Dataset Statistics\n",
    "\n",
    "Below are some plots and statistics about the Narrowband dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "class_list = static_narrowband.dataset_metadata.class_list\n",
    "snr_list = []\n",
    "\n",
    "class_counter = {class_name: 0 for class_name in class_list}\n",
    "\n",
    "for sample in tqdm(static_narrowband, desc = \"Calculating Narrowband Stats\"):\n",
    "    data, targets = sample\n",
    "    classname, start, stop, lower, upper, snr = targets\n",
    "    snr_list.append(snr)\n",
    "    class_counter[classname] += 1\n",
    "\n",
    "class_counts = list(class_counter.values())\n",
    "class_names = list(class_counter.keys())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class Distribution Pie Chart\n",
    "# by default, the class distribution is None aka uniform\n",
    "print(f\"Class Distribution Setting: {static_narrowband.dataset_metadata.class_distribution}\")\n",
    "plt.figure(figsize=(12, 12))\n",
    "plt.pie(class_counts, labels = class_names)\n",
    "plt.title(\"Class Distribution Pie Chart\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class Distribution Bar Chart\n",
    "plt.figure(figsize=(11, 4))\n",
    "plt.bar(class_names, class_counts)\n",
    "plt.xticks(rotation=90)\n",
    "plt.title(\"Class Distribution Bar Chart\")\n",
    "plt.xlabel(\"Modulation Class Name\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SNR Distributions\n",
    "# For narrowband, the default min and max SNR is 0 and 50 db, respectively\n",
    "print(f\"Min SNR Setting: {static_narrowband.dataset_metadata.snr_db_min}\")\n",
    "print(f\"Max SNR Setting: {static_narrowband.dataset_metadata.snr_db_max}\")\n",
    "plt.figure(figsize=(11, 4))\n",
    "plt.hist(x=snr_list, bins=100)\n",
    "plt.title(\"SNR Distribution\")\n",
    "plt.xlabel(\"SNR Bins (dB)\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:percent"
  },
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
